import numpy as np
import utils.fmodule
from utils import fmodule
from utils.backdoor import *
import copy
import os
import utils.fflow as flw
import math
import collections
import torch
from tqdm import tqdm
import pickle
from torchvision import transforms
import random
from collections import OrderedDict
import time
import torch.optim as optim
class UnlearnBasicServer:
    def __init__(self, option, model, clients, data_loader, device=None):
        """Configure the server, decide the run stage and prepare every client.

        Args:
            option: dict of run options. Keys read here: 'algorithm', 'bd',
                'num_rounds', 'learning_rate', 'lr_decay', 'lr_scheduler', 'P',
                'momentum', 'weight_decay', 'aggregate', 'dataloader', 'model',
                'u_clients', 'split_num', 'class_num', 'seed', 'batch_size',
                'dir_a' and, when a pretrained checkpoint is found,
                'u_rounds' and 'p_rounds'.
            model: global model; moved to ``device`` and deep-copied into
                ``self.old_model``.
            clients: list of client objects. They must expose ``id``,
                ``datavol``, ``test_datavol``, ``data_name``,
                ``train_dataset`` and ``test_dataset`` (unlearn clients also
                ``training_number``).
            data_loader: dataset wrapper, stored as ``self.dataloader``.
            device: torch device; defaults to ``cuda:0`` when CUDA is
                available, otherwise ``cpu``.

        Raises:
            KeyError: if a loaded checkpoint has neither the legacy
                'domian_var' key nor 'domain_var'.
        """
        self.name = option['algorithm']  # the server is named after the algorithm
        self.bd = option['bd']
        self.model = model
        if device is None:
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.model.to(self.device)
        self.old_model = copy.deepcopy(self.model)
        self.dataloader = data_loader
        self.current_rounds = 1  # current communication round
        # clients configuration
        self.clients = clients  # every client object
        self.clients_id = [client.id for client in self.clients]  # registered client ids
        self.num_clients = len(self.clients)

        self.local_data_vols = [c.datavol for c in self.clients]
        self.local_test_vols = [c.test_datavol for c in self.clients]
        self.total_data_vol = sum(self.local_data_vols)

        # TODO: this may be optimised later; re-check the logic
        self.selected_clients = []

        # hyper-parameters during training process
        self.num_rounds = option['num_rounds']
        self.lr = option['learning_rate']
        for c in self.clients:
            c.lr = self.lr  # initialise the client learning rate
            setattr(c, "device", device)

        self.lr_decay = option['lr_decay']
        self.lr_scheduler_type = option['lr_scheduler']
        self.clients_per_round = max(int(self.num_clients * option['P']), 1)
        self.local_momentum = option['momentum']
        # TODO: pass weight_decay to the optimizer
        self.weight_decay = option['weight_decay']

        # uniform
        # self.sample_option = option['sample']
        self.aggregation_option = option['aggregate']

        # algorithm-dependent parameters
        self.algo_para = {}

        # all options
        self.option = option

        self.stage = None
        self.save_folder = os.path.join('output',
                                        option['dataloader'],
                                        option['model'],
                                        'FU',
                                        'unlearn_clients' + str(option['u_clients']),
                                        option['algorithm'])

        # log
        self.out_log = ''
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(self.save_folder, exist_ok=True)

        # Stage decision: the checkpoint path is derived from the folder as it
        # is *before* the optional ' dir_a ' suffix is appended below.
        if self.name != 'retrain':
            model_path = os.path.join(os.path.dirname(self.save_folder), 'fedavg', f"pretrain_s{self.option['split_num']}_c{self.option['class_num']}_bd_{str(self.bd)}.pth")
        else:
            model_path = os.path.join(os.path.dirname(self.save_folder), 'retrain', f"s{self.option['split_num']}_c{self.option['class_num']}.pth")

        recreate = False  # True only for debug

        # A checkpoint exists and no re-creation is requested: start unlearning.
        if os.path.isfile(model_path) and not recreate:
            self.stage = 'Unlearn'
            pretrain_info = torch.load(model_path, map_location=self.device)
            self.model.load_state_dict(pretrain_info['model_state_dict'])
            try:
                # Legacy checkpoints stored the value under a misspelled key.
                self.pretrain_domain_var = pretrain_info['domian_var']
            except KeyError:
                self.pretrain_domain_var = pretrain_info['domain_var']

            # TODO: the checkpoint also holds lr, momentum and weight_decay
            self.u_rounds = option['u_rounds']
            self.p_rounds = option['p_rounds']
            print(f'Unlearning stage, U_rounds is {self.u_rounds}, Current model is loaded from {model_path}')
        # No checkpoint: start with the pretraining stage.
        else:
            self.stage = 'Pretrain'
            self.pretrain_domain_var = 0
            print(f"Pretrained model is not founded. Let's pretrain first")

        # save_name
        self.save_name = str(option['seed']) + '_' + option['dataloader'] + ' bs' + str(option['batch_size']) + '_' + \
                         option['algorithm'] + '_lr_' \
                         + str(option['learning_rate']) + '_decay_' + str(option['lr_decay']) + \
                         '_mo_' + str(option['momentum']) + ' agg ' + str(option['aggregate']) + \
                        ' split ' + str(option['split_num'])+ ' class ' + str(option['class_num']) + ' bd ' + str(option['bd'])

        if option['dir_a'] != -1:
            self.save_folder += ' dir_a ' + str(option['dir_a'])
            # The folder name just changed, so make sure the directory that
            # logs and checkpoints are written to actually exists.
            os.makedirs(self.save_folder, exist_ok=True)

        # Initialise the unlearn clients.
        # NOTE(review): the client at index 4 is hard-coded; the random
        # sampling by option['u_clients'] is disabled.
        self.unlearn_clients = [self.clients[4]] # random.sample(self.clients, option['u_clients'])
        self.unlearn_clients_id = [c.id for c in self.unlearn_clients]
        self.test_domain_data_vol = {}
        for c in self.clients:
            self.out_log += f'Client {c.id} local statistics: {c.id}: {c.data_name} \n'
            print(f'Client {c.id} local statistics: ', {c.id: c.data_name})
            # Accumulate element [1] of each data_name entry per domain key.
            for k, v in c.data_name.items():
                if k in self.test_domain_data_vol:
                    self.test_domain_data_vol[k] += v[1]
                else:
                    self.test_domain_data_vol[k] = v[1]

        # NOTE(review): indexes self.clients by client id, i.e. assumes
        # client.id equals the client's position in the list.
        unlearn_statistics = {uid: self.clients[uid].data_name for uid in self.unlearn_clients_id}
        self.out_log += f'Unlearn Statistics: {unlearn_statistics} \n \n'
        print(f'Unlearn Statistics: {unlearn_statistics}')
        for c in self.clients:
            if c.id in self.unlearn_clients_id:
                setattr(c, 'unlearn', True)
                self.modify_client(c)
            else:
                setattr(c, 'unlearn', False)

            setattr(c, "bd", self.bd)
            setattr(c, 'stage', self.stage)
            setattr(c, 'train_data', torch.utils.data.DataLoader(c.train_dataset, shuffle=True, batch_size=option['batch_size'], num_workers=6, pin_memory=True))
            setattr(c, 'test_data', torch.utils.data.DataLoader(c.test_dataset, shuffle=False, batch_size=option['batch_size'], num_workers=6, pin_memory=True))

    # TODO: consider adding a plain unlearning variant that does not use a backdoor
    def modify_client(self, client):
        """Equip an unlearn client with a backdoor maker and an "UM" loader.

        Sets three attributes on ``client``:
            bd_maker: a ``FigRandBackdoor`` built from this server's
                dataloader and save folder.
            UM_test_data: a non-shuffled DataLoader over the client's own
                training dataset.
            UM_test_datavol: the client's ``training_number``.
        """
        client.bd_maker = FigRandBackdoor(dataloader=self.dataloader, save_folder=self.save_folder)

        # The UM test set equals the training data and is used to measure the
        # forgetting effect; the backdoor maker is used to measure ASR.
        um_loader = torch.utils.data.DataLoader(
            client.train_dataset,
            shuffle=False,
            batch_size=self.option['batch_size'],
            num_workers=6,
        )
        client.UM_test_data = um_loader
        client.UM_test_datavol = client.training_number

    def outFunc(self, t_metric, ):
        """
        Summarise one round of client test metrics, append the summary to
        ``self.out_log`` and print it.

        t_metric: {'retain_accuracy': [], 'retain_loss': [],
                    'Backdoor_accuracy': [], 'Backdoor_loss': [],
                    'Unlearn_Memory_accuracy': [], 'Unlearn_Memory_loss': []
                    'domain_metric': []}

        'domain_metric' is a list of dicts (one per client, in client order)
        mapping a domain key to a count. Entries whose position is listed in
        ``self.unlearn_clients_id`` are skipped.

        In the 'Pretrain' stage the variance of the per-domain percentages is
        stored in ``self.pretrain_domain_var``; in every other stage its
        absolute difference to that stored value is logged instead.
        """
        # TODO: fix the domain-metric bug for unlearn clients
        clients_results = t_metric['domain_metric']

        # Merge the per-domain counts of all retained clients.
        domain_results = {}
        for cid, d in enumerate(clients_results):
            if cid in self.unlearn_clients_id:
                continue
            for key, value in d.items():
                key = self.dataloader.domain_dict.int2str(key)
                if key in domain_results:
                    domain_results[key] += value
                else:
                    domain_results[key] = value

        # Turn counts into percentages, kept as 2-decimal strings because the
        # strings themselves are what gets logged.
        for k, v in domain_results.items():
            domain_results[k] = f'{(v / self.test_domain_data_vol[k])*100:.2f}'
        list_results = np.array([float(v) for v in domain_results.values()])

        retain_acc = np.array(t_metric['retain_accuracy'])
        retain_loss = np.array(t_metric['retain_loss'])
        BD_acc = np.array(t_metric['Backdoor_accuracy'])
        UM_acc = np.array(t_metric['Unlearn_Memory_accuracy'])

        def cal_fairness(values):
            """Angle between ``values`` and the all-ones vector."""
            p = np.ones(len(values))
            cosine = values @ p / (np.linalg.norm(values) * np.linalg.norm(p))
            # Rounding can push the cosine marginally above 1 (e.g. when all
            # values are equal), which would make arccos return NaN.
            return np.arccos(np.clip(cosine, -1.0, 1.0))

        # write log
        out_log = ""
        out_log += f'Stage: {self.stage}, Round: {self.current_rounds}, Lr: {self.lr}' + '\n'
        if len(self.unlearn_clients_id):
            out_log += f'Unlearn Clients ID: {self.unlearn_clients_id}' + '\n'
        # TODO: note that no 5% / 10% quantile is taken here, just min and max
        if len(BD_acc) > 0:
            unlearned_client_fairness = cal_fairness(BD_acc)
            out_log += f'Unlearned Client Local Test Acc: {format(np.mean(BD_acc), ".3f")}({format(np.std(BD_acc), ".3f")}), angle: {format(unlearned_client_fairness, ".6f")}, min: {format(np.min(BD_acc), ".6f")}, max: {format(np.max(BD_acc), ".6f")}' + '\n'
        if len(UM_acc) > 0:
            out_log += f'Unlearn Memory Local Test Acc: {format(np.mean(UM_acc), ".3f")}({format(np.std(UM_acc), ".3f")}), min: {format(np.min(UM_acc), ".6f")}, max: {format(np.max(UM_acc), ".6f")}' + '\n'
        if len(retain_loss) > 0:
            out_log += f'Retained Client Mean Global Test loss: {format(np.mean(retain_loss), ".6f")}' + '\n'
        # Guarded like the lines above: np.min / np.max raise on empty input.
        if len(retain_acc) > 0:
            retained_client_fairness = cal_fairness(retain_acc)
            out_log += f'Retained Client Local Test Acc: {format(np.mean(retain_acc), ".3f")}({format(np.std(retain_acc), ".3f")}), angle: {format(retained_client_fairness, ".6f")}, min: {format(np.min(retain_acc), ".6f")}, max: {format(np.max(retain_acc), ".6f")}' + '\n'
        out_log += f'System Results: Retain {retain_acc}, Unlearn {BD_acc}' + '\n'
        out_log += f'Domain Results: {domain_results}' + '\n'
        if self.stage != 'Pretrain':
            out_log += f'Domain Equitability: {format(np.abs(np.var(list_results) - float(self.pretrain_domain_var)), ".2f")}' + '\n'
        else:
            self.pretrain_domain_var = np.var(list_results)
        # TODO: consider recording communication / computation time
        out_log += '\n'
        self.out_log = self.out_log + out_log
        print(out_log)

    def save_log(self, stream_log):
        if self.stage == 'Pretrain':
            file_name = os.path.join(self.save_folder, self.save_name + '.log')
        else:
            file_name = os.path.join(self.save_folder, 'Unlearn' + self.save_name + '.log')
        fileObject = open(file_name, 'w')
        fileObject.write(stream_log)
        fileObject.close()

    # def save_loss(self, loss_list):
    #     file_name = self.save_folder + self.save_name + '.log'
    #     with open(file_name, 'wb') as file:
    #         pickle.dump(loss_list, file)

    def run(self):
        """Drive the whole experiment for the stage chosen at construction.

        'Pretrain': ``num_rounds`` rounds of ``iterate``, each followed by
        ``pretrain_save``, an lr-scheduler step and an evaluation; then a
        checkpoint is written.

        'Unlearn': ``u_rounds`` rounds of ``unlearn_iterate``, after which the
        stage switches to 'PT' for ``p_rounds`` rounds of ``pt_iterate``;
        then a checkpoint is written.
        """
        self.current_rounds = 0

        def evaluate_and_log():
            # Test the current global model on all clients and persist the log.
            metrics = self.test_on_clients(dataflag='test', model=self.model)
            self.outFunc(metrics)
            self.save_log(self.out_log)

        if self.stage == 'Pretrain':
            for round_idx in tqdm(range(1, self.num_rounds + 1), desc='Pretraining Rounds'):
                self.current_rounds = round_idx
                # federated training
                self.iterate()
                # history is saved here so the next global broadcast cannot
                # affect the stored results
                self.pretrain_save()
                self.global_lr_scheduler(self.num_rounds)
                evaluate_and_log()
            self.save_ckp()

        if self.stage == 'Unlearn':
            for round_idx in tqdm(range(1, self.u_rounds + 1), desc='Unlearning Rounds'):
                self.current_rounds = round_idx
                # federated unlearning, including the global model update
                self.unlearn_iterate()
                self.global_lr_scheduler(self.num_rounds)
                evaluate_and_log()

            self.stage = 'PT'
            for round_idx in tqdm(range(1, self.p_rounds + 1), desc='Post-training Rounds'):
                self.current_rounds = round_idx
                # federated post-training
                self.pt_iterate()
                self.global_lr_scheduler(self.p_rounds)
                evaluate_and_log()
            self.save_ckp()

    def iterate(self):
        """Run one pretraining round in which every client participates.

        All client indices are selected, each client is contacted through
        ``communicate``, and the returned models are aggregated into
        ``self.model``.
        """
        everyone = range(len(self.clients))
        self.selected_clients = everyone

        reply = self.communicate(everyone)
        # `communicate` stores the contacted clients in self.received_clients.
        client_models, _client_losses = reply['model'], reply['loss']

        self.model = self.aggregate(client_models)

    def unlearn_iterate(self):
        """Run one default unlearning round: sample, communicate, aggregate.

        A subset of clients is drawn with ``sample``, contacted through
        ``communicate``, and the returned models are aggregated into
        ``self.model``.

        A leftover ``assert 1==0`` debugging statement (plus unused timing
        variables) used to abort every call right after aggregation; it has
        been removed so the method completes.
        """
        self.selected_clients = self.sample()
        reply = self.communicate(self.selected_clients)
        # `communicate` stores the contacted clients in self.received_clients.
        self.model = self.aggregate(reply['model'])

    def pt_iterate(self):
        """Run one post-training round that leaves out the unlearned clients.

        Clients are sampled as usual, every sampled id that appears in
        ``self.unlearn_clients_id`` is dropped, and the remaining clients'
        models are aggregated into ``self.model``.
        """
        self.selected_clients = self.sample()
        # Positions inside the sample that belong to unlearned clients.
        unlearned_positions = np.where(
            np.isin(self.selected_clients, self.unlearn_clients_id)
        )[0]
        # TODO: this filtering step was flagged as a possible source of problems
        self.selected_clients = np.delete(self.selected_clients, unlearned_positions)

        reply = self.communicate(self.selected_clients)
        # `communicate` stores the contacted clients in self.received_clients.
        client_models, _client_losses = reply['model'], reply['loss']

        self.model = self.aggregate(client_models)

    def pretrain_save(self):
        if self.current_rounds == 1:
            self.update_his = OrderedDict()        # for FedKDU
            for client in self.unlearn_clients:
                self.update_his[client.id] = []

        self.pretrain_KDU_save()
        self.pretrain_eraser_save()
        self.pretrain_recovery_save()
        if self.current_rounds == self.num_rounds:
            rcy_folder = os.path.join(self.save_folder, 'pretrained_history_fedrecovery')
            kdu_folder = os.path.join(self.save_folder, 'pretrained_history_fedkdu')
            era_folder = os.path.join(self.save_folder, 'pretrained_history_federaser')

            save_name = os.path.join(rcy_folder,
                                     f"global_model_norm_square_s{self.option['split_num']}_c{self.option['class_num']}_bd_{str(self.bd)}.csv")
            np.savetxt(save_name, np.array(self.global_model_norm_square), delimiter=',')

            with open(os.path.join(kdu_folder, f"kdu_his_s{self.option['split_num']}_c{self.option['class_num']}_bd_{str(self.bd)}.pkl"),
                      'wb') as f:
                pickle.dump(self.update_his, f)

                # TODO: 重新定义位置
                with open(os.path.join(era_folder,
                                       f"eraser_his_five_s{self.option['split_num']}_c{self.option['class_num']}_bd_{str(self.bd)}.pkl"),
                          'wb') as f:
                    pickle.dump(self.eraser_his_five, f)

    def pretrain_recovery_save(self):
        """
        专门为FedRecovery保存client的历史数据。
        需要在pretrain阶段存储论文伪代码的 ||▽F(w_i)||^2 * lr * ▽f_iu(w_i))。
        另外还要存储历代global model的norm_square
        """
        folder = os.path.join(self.save_folder, 'pretrained_history_fedrecovery')
        if not os.path.exists(folder):
            os.makedirs(folder)
        if self.current_rounds == 1:
            client_mid_value_dict = {}
            for client in self.clients:
                client_mid_value_dict[client.id] = 0  # 初始化
            setattr(self, "client_mid_value_dict", client_mid_value_dict)
            setattr(self, "global_model_norm_square", [])  # 存储历代的global model norm square

        # 保存global_model_norm_square
        norm_square = float(torch.norm(torch.nn.utils.parameters_to_vector(self.model.parameters())).to('cpu') ** 2)
        self.global_model_norm_square.append(norm_square)

        # 下面保存||▽F(w_i)||^2 * lr * ▽f_iu(w_i))
        for client in self.clients:
            with torch.no_grad():  # 避免计算图的累积
                upload_model = client.model
                # 这个地方需要定义一下
                update = (torch.nn.utils.parameters_to_vector(self.old_model.parameters()) - torch.nn.utils.parameters_to_vector(upload_model.parameters())).to('cpu')
                mid_value = update * norm_square  # 要保存的内容
                self.client_mid_value_dict[client.id] += mid_value
                del upload_model
                del update
                del mid_value
                torch.cuda.empty_cache()

                # 保存
                if self.current_rounds % 50 == 0:
                    save_name = os.path.join(folder, 'client_' + str(client.id) + f"_mid_value_s{self.option['split_num']}_c{self.option['class_num']}_bd_{str(self.bd)}.csv")
                    np.savetxt(save_name, self.client_mid_value_dict[client.id].to("cpu").detach().numpy(),
                               delimiter=',')

    def pretrain_KDU_save(self):
        folder = os.path.join(self.save_folder, 'pretrained_history_fedkdu')
        if not os.path.exists(folder):
            os.makedirs(folder)
        with torch.no_grad():  # 避免计算图的累积

            for client in self.unlearn_clients:
                up = client.model - self.model
                assert client.id in self.update_his.keys()
                self.update_his[client.id].append({k: v.cpu() for k, v in up.state_dict().items()})

                del up
                torch.cuda.empty_cache()


            # with open(self.file_path, 'rb') as f:
            #    self.update_his = pickle.load(f)

    def pretrain_eraser_save(self):
        folder = os.path.join(self.save_folder, 'pretrained_history_federaser')
        if not os.path.exists(folder):
            os.makedirs(folder)
        if self.current_rounds == 1:
            self.eraser_his_five = {} # save norm of update per 5 rounds delta==5
            self.last_model_five = copy.deepcopy(self.model).to('cpu')
            for client in self.clients:
                self.eraser_his_five[client.id] = []
        else:
            if self.current_rounds % 5 == 0:
                with torch.no_grad():  # 避免计算图的累积
                    for client in self.clients:
                        new_model = client.model
                        update_norm = (torch.norm(torch.nn.utils.parameters_to_vector(new_model.parameters()).to('cpu') - torch.nn.utils.parameters_to_vector(self.last_model_five.parameters()))).to('cpu')
                        self.eraser_his_five[client.id].append(update_norm)
                    torch.cuda.empty_cache()

                self.last_model_five = copy.deepcopy(self.model).to('cpu')

    def save_ckp(self):
        # save ckp
        if self.stage == 'Pretrain':
            file_name = os.path.join(self.save_folder, f"pretrain_s{self.option['split_num']}_c{self.option['class_num']}_bd_{str(self.bd)}.pth")
        elif self.stage == 'Unlearn':
            file_name = os.path.join(self.save_folder, f"unlearn_s{self.option['split_num']}_c{self.option['class_num']}_bd_{str(self.bd)}.pth")
        elif self.stage == 'PT':
            file_name = os.path.join(self.save_folder, f"pt_s{self.option['split_num']}_c{self.option['class_num']}_bd_{str(self.bd)}.pth")
        else:
            raise RuntimeError('Double Check The FL Stage. (请检查FL阶段是否正确！)')
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'lr': self.lr,
            'momentum': self.local_momentum,
            'weight_decay': self.weight_decay,
            'domain_var': self.pretrain_domain_var,
        }, file_name)
        print(f'{self.stage} model is saved.')
        return

    def communicate(self, selected_clients, o_model=None):
        """Contact every selected client in turn and merge their replies.

        Args:
            selected_clients: iterable of client indices, contacted in order.
            o_model: optional model to send instead of ``self.model``.

        Returns:
            The result of ``unpack`` over all truthy reply packages, in the
            order of ``selected_clients``.

        Side effects:
            ``self.received_clients`` is set to ``selected_clients``.
        """
        # Query the clients one after another.
        replies = [self.communicate_with(cid, o_model) for cid in selected_clients]

        # Index the replies by client id; a repeated id keeps its last reply.
        reply_by_client = {}
        for cid, package in zip(selected_clients, replies):
            reply_by_client[cid] = package

        # Drop empty / falsy packages while keeping the selection order.
        usable = [reply_by_client[cid] for cid in selected_clients if reply_by_client[cid]]
        self.received_clients = selected_clients
        return self.unpack(usable)

    # TODO: the duplicated design here could be optimised
    def communicate_with(self, communicate_client_id, o_model=None):
        """Send one package to a single client and return that client's reply.

        Args:
            communicate_client_id: index into ``self.clients``.
            o_model: optional model to pack instead of ``self.model``.
        """
        if o_model is not None:
            package = self.pack(communicate_client_id, o_model)
        else:
            package = self.pack(communicate_client_id)
        target = self.clients[communicate_client_id]
        return target.reply(package)

    def pack(self, client_id, model=None):
        """Build the package that the server sends to a client.

        Args:
            client_id: id of the receiving client (not used by this base
                implementation).
            model: model to ship; ``self.model`` is used when this is None.

        Returns:
            dict with a deep copy of the model plus the current round,
            learning rate, momentum, weight decay and stage.
        """
        outgoing = self.model if model is None else model
        return {
            "model": copy.deepcopy(outgoing),
            "current_rounds": self.current_rounds,
            "lr": self.lr,
            "momentum": self.local_momentum,
            "weight_decay": self.weight_decay,
            "stage": self.stage,
        }

    def unpack(self, packages_received_from_clients):
        """Turn a list of reply dicts into one dict of lists.

        The field names are taken from the first package; each field maps to
        the list of that field's values across all packages, in order. An
        empty input yields an empty ``defaultdict(list)``.
        """
        if len(packages_received_from_clients) == 0:
            return collections.defaultdict(list)
        merged = {field: [] for field in packages_received_from_clients[0]}
        for package in packages_received_from_clients:
            for field in package:
                merged[field].append(package[field])
        return merged

    def global_lr_scheduler(self, current_round):
        # print('lr scheduler type:', self.lr_scheduler_type)
        # assert 1==0
        if self.lr_scheduler_type == -1:
            return
        elif self.lr_scheduler_type == 0:
            """eta_{round+1} = DecayRate * eta_{round}"""
            self.lr *= self.lr_decay
            # for c in self.clients:
            #     c.set_learning_rate(self.lr)
        # todo: if using type 1, define 1 more self.ini_lr
        # elif self.lr_scheduler_type == 1:
        #     """eta_{round+1} = eta_0/(round+1)"""
        #     self.lr = self.option['learning_rate'] * 1.0 / (current_round + 1)
        #     for c in self.clients:
        #         c.set_learning_rate(self.lr)

    def sample(self):
        all_clients_id = self.clients_id
        selected_clients = np.random.choice(all_clients_id, min(self.clients_per_round, len(all_clients_id)),
                                            replace=False)
        return selected_clients

    def aggregate(self, models: list, *args, **kwargs):
        """
        Combine client models into a new global model according to
        ``self.aggregation_option``. An empty ``models`` list returns
        ``self.model`` unchanged. Data-volume weights are looked up through
        ``self.received_clients``.

        -------------------------------------------------------------------------------------------------------------------------
         weighted_scale                 |uniform (default)          |weighted_com (original fedavg)   |other
        ==========================================================================================================================
        N/K * Σpk * model_k             |1/K * Σmodel_k             |(1-Σpk) * w_old + Σpk * model_k  |Σ(pk/Σpk) * model_k
        """
        if len(models) == 0:
            return self.model

        option = self.aggregation_option

        if option == 'uniform':
            return fmodule._model_average(models)

        received_vols = [self.local_data_vols[cid] for cid in self.received_clients]

        if option == 'weighted_scale':
            # Weights are normalised over the clients that replied.
            received_total = sum(received_vols)
            p = [1.0 * vol / received_total for vol in received_vols]
            if self.stage == 'PT':
                # Post-training: the unlearned clients no longer count.
                N = self.num_clients - len(self.unlearn_clients_id)
                print(N, p)
            else:
                N = len(self.selected_clients)
            print(f'Aggregation rate: {p}')
            weighted = [model_k * pk for model_k, pk in zip(models, p)]
            return fmodule._model_sum(weighted) * N / len(models)

        # Remaining options weight by the share of the total data volume.
        p = [1.0 * vol / self.total_data_vol for vol in received_vols]

        if option == 'weighted_com':
            w = fmodule._model_sum([model_k * pk for model_k, pk in zip(models, p)])
            return (1.0 - sum(p)) * self.model + w

        sump = sum(p)
        p = [pk / sump for pk in p]
        return fmodule._model_sum([model_k * pk for model_k, pk in zip(models, p)])

    def test_on_clients(self, dataflag='test', model=None):
        all_metrics = collections.defaultdict(list)
        test_model = self.model if model is None else model
        for cid, c in enumerate(self.clients):
            test_model = self.model if model is None else model
            client_metrics = c.test(test_model, dataflag)
            for met_name, met_val in client_metrics.items():
                all_metrics[met_name].append(met_val)
        return all_metrics

    def init_algo_para(self, algo_para: dict):
        """Register algorithm-specific hyper-parameters on server and clients.

        Args:
            algo_para: dict of parameter name -> default value. When
                ``self.option['algo_para']`` is not None, its values override
                the defaults positionally (in the dict's key order) and are
                cast to the type of the corresponding default.

        Each resulting parameter becomes an attribute of the server and of
        every client. Values supplied beyond the number of declared
        parameters are ignored; the old bounds check compared the index with
        the wrong length and so never triggered, leading to an IndexError.
        """
        self.algo_para = algo_para
        if len(self.algo_para) == 0: return
        overrides = self.option['algo_para']
        if overrides is not None:
            # zip stops at the shorter sequence, so extra values are dropped.
            for para_name, pv in zip(list(self.algo_para), overrides):
                self.algo_para[para_name] = type(self.algo_para[para_name])(pv)
        # register the algorithm-dependent hyperparameters as the attributes of the server and all the clients
        for para_name, value in self.algo_para.items():
            setattr(self, para_name, value)
            for c in self.clients:
                setattr(c, para_name, value)
        return

    def data_to_device(self, data, device):
        new_data = None
        if type(data) == torch.Tensor:
            new_data = data.to(device)
        elif type(data) == tuple:
            new_data = []
            for item in data:
                item = item.to(device)
                new_data.append(item)
            new_data = tuple(new_data)
        elif type(data) == list:
            new_data = []
            for item in data:
                item = item.to(device)
                new_data.append(item)
        return new_data


class UnlearnBasicClient():
    def __init__(self, option, id, model=None, ):
        self.id = id
        # hyper-parameters for training
        self.optimizer_name = option['optimizer']
        self.lr = option['learning_rate']
        self.batch_size = option['batch_size']
        self.momentum = option['momentum']
        self.weight_decay = option['weight_decay']
        self.epochs = option['num_epochs']
        self.model = model
        self.current_steps = 0

        # server
        self.current_rounds = 0
        self.criterion = torch.nn.CrossEntropyLoss()

        # unlearn


    def update_data(self, id, local_training_data, local_training_number, local_test_data, local_test_number, data_name):
        """Attach local train/test data to the client and derive step counts.

        Args:
            id: Identifier; overwrites ``self.id``.
            local_training_data: Stored as ``self.train_dataset``; must support ``len``.
            local_training_number: Stored as ``training_number`` and ``datavol``.
            local_test_data: Stored as ``self.test_dataset``.
            local_test_number: Stored as ``test_number`` and ``test_datavol``.
            data_name: Stored as ``self.data_name``.
        """
        self.id = id
        self.data_name = data_name
        self.train_dataset, self.test_dataset = local_training_data, local_test_data
        self.training_number = self.datavol = local_training_number
        self.test_number = self.test_datavol = local_test_number
        # total optimisation steps = epochs x batches per epoch (last batch may be partial)
        batches_per_epoch = math.ceil(len(self.train_dataset) / self.batch_size)
        self.num_steps = self.epochs * batches_per_epoch

    def js(self, p_output, q_output):
        """Return the symmetric KL-to-mixture divergence of two distributions.

        The mixture is the element-wise mean of ``p_output`` and ``q_output``;
        the result is the average of the KL terms of each input against it.

        NOTE(review): ``reduction='mean'`` averages over every element; PyTorch
        documents ``'batchmean'`` as the reduction matching the mathematical
        KL definition. Left as-is so the loss scale does not change.
        """
        kl_div = torch.nn.functional.kl_div
        # kl_div expects its first argument in log-space
        log_mixture = torch.log((p_output + q_output) / 2)
        kl_p = kl_div(log_mixture, p_output, reduction='mean')
        kl_q = kl_div(log_mixture, q_output, reduction='mean')
        return (kl_p + kl_q) / 2

    def data_to_device(self, data, device):
        """Move a tensor, or a tuple/list of tensors, onto ``device``.

        Args:
            data: A ``torch.Tensor``, or a tuple/list whose items expose ``.to``.
            device: Target handed to every ``.to`` call.

        Returns:
            The moved tensor, or a tuple/list (matching the input kind) of the
            moved items. Any other input type yields None.
        """
        # isinstance rather than an exact type match, so Tensor subclasses
        # (e.g. nn.Parameter) are moved instead of silently becoming None.
        if isinstance(data, torch.Tensor):
            return data.to(device)
        if isinstance(data, tuple):
            return tuple(item.to(device) for item in data)
        if isinstance(data, list):
            return [item.to(device) for item in data]
        return None

    def cal_gradient_loss(self, ):
        """Compute the batch-size-weighted mean gradient and the mean loss.

        Iterates ``self.train_data`` once. For each batch the gradient of the
        batch loss is flattened with ``fmodule._grad2vec``; the per-batch
        gradients are then averaged with weights proportional to batch size.
        No optimizer step is taken, so the model parameters are not updated.

        When ``self.LSR is False`` the loss is ``self.criterion`` on the model
        output. Otherwise the loss combines three terms built from the original
        batch and its ``self.tt_transform`` augmentation: ``self.criterion`` on
        a sharpened mixture of the two predictions, ``self.js`` between the two
        views scaled by a ramped ``self.gamma``, and an entropy term scaled by
        ``self.lambda_e``.

        Returns:
            (grad, loss): the weighted-average gradient vector and the summed
            batch losses divided by ``self.datavol``.
        """
        sm = torch.nn.Softmax(dim=1)
        lsm = torch.nn.LogSoftmax(dim=1)
        self.model.train()
        # No optimizer is built here: this method only reads gradients. The
        # earlier argument-less get_optimizer() call raised TypeError before
        # any batch was processed, and its result was never used.
        weights = []
        grad_mat = []
        total_loss = 0
        for batch_data in self.train_data:
            batch_x, batch_y = batch_data['image'], batch_data['label']
            weights.append(batch_y.shape[0])
            batch_x = self.data_to_device(batch_x, device=self.device)
            batch_y = self.data_to_device(batch_y, device=self.device)
            if self.LSR is False:
                outputs = self.model(batch_x)
                loss = self.criterion(outputs, batch_y)
            else:
                batch_x_aug = self.tt_transform(batch_x)
                batch_x_aug = self.data_to_device(batch_x_aug, self.device)
                output1 = self.model(batch_x)
                output2 = self.model(batch_x_aug)
                mix_1 = np.random.beta(1, 1)  # mixing predict1 and predict2
                mix_2 = 1 - mix_1
                logits1, logits2 = torch.softmax(output1 * 3, dim=1), torch.softmax(output2 * 3, dim=1)

                logits1, logits2 = torch.clamp(logits1, min=1e-6, max=1.0), torch.clamp(logits2, min=1e-6, max=1.0)
                # NOTE(review): both entropy terms are computed from logits1,
                # so this equals the view-1 entropy alone. The second term
                # presumably should use logits2; kept unchanged so the loss
                # value stays the same until the author confirms.
                L_e = - (torch.mean(torch.sum(sm(logits1) * lsm(logits1), dim=1)) + torch.mean(
                    torch.sum(sm(logits1) * lsm(logits1), dim=1))) * 0.5
                p = torch.softmax(output1, dim=1) * mix_1 + torch.softmax(output2, dim=1) * mix_2
                pt = p ** (2)
                pred_mix = pt / pt.sum(dim=1, keepdim=True)
                # linearly ramp the consistency weight up over the first
                # global_t_w rounds
                betaa = self.gamma
                if (self.current_rounds < self.global_t_w):
                    betaa = self.gamma * self.current_rounds / self.global_t_w

                loss = self.criterion(pred_mix, batch_y)  # to compute cross entropy loss
                loss += self.js(logits1, logits2) * betaa
                loss += L_e * self.lambda_e
            # Clear stale gradients right before backward so the flattened
            # vector holds this batch's gradient only.
            self.model.zero_grad()
            loss.backward()
            grad_mat.append(fmodule._grad2vec(self.model))
            total_loss += loss.item() * batch_y.shape[0]
        loss = total_loss / self.datavol
        weights = torch.Tensor(weights).float().to(self.device)
        weights /= torch.sum(weights)
        grad_mat = torch.stack(grad_mat).to(self.device)
        grad = weights @ grad_mat
        return grad, loss

    # def make_backdoor_batch(self, bd_batch_x, bd_batch_y, p=0.5, s='train'):
    #     if self.mask is None and self.pattern is None:
    #         full_image = torch.zeros(bd_batch_x[0].shape).fill_(self.mask_value)
    #         full_image[:, self.x_top:self.x_bot, self.y_top:self.y_bot] = self.pattern_tensor
    #         self.mask = 1 * (full_image != self.mask_value)
    #         self.pattern = full_image
    #         if 'cifar10' in self.task_name:
    #             means = (0.4914, 0.4822, 0.4465)
    #             lvars = (0.2023, 0.1994, 0.2010)
    #             normalize = transforms.Normalize(means, lvars)
    #             self.pattern = normalize(self.pattern)
    #
    #     attack_portion = round(len(bd_batch_y) * p) if s == 'train' else len(bd_batch_y)
    #     bd_batch_x[:attack_portion] = (1 - self.mask) * bd_batch_x[:attack_portion] + self.mask * self.pattern
    #     bd_batch_y[:attack_portion].fill_(self.backdoor_label)
    #     return bd_batch_x, bd_batch_y
    def train(self, ):
        self.model.train()
        total_loss = 0.0
        optimizer = self.get_optimizer(self.model)
        for e in range(self.epochs):
            # for step, (batch_x, batch_y) in enumerate(self.train_data):
            for batch_id, batch_data in enumerate(self.train_data):
                batch_x, batch_y = batch_data['image'], batch_data['label']
                if self.unlearn and self.bd:
                    batch_x, batch_y = self.bd_maker.add_backdoor(batch_x, batch_y)
                self.model.zero_grad()
                batch_x = self.data_to_device(batch_x, device=self.device)
                batch_y = self.data_to_device(batch_y, device=self.device)
                outputs = self.model(batch_x)
                loss = self.criterion(outputs, batch_y)
                loss.backward()
                optimizer.step()
                batch_mean_loss = loss.item()
                total_loss += batch_mean_loss * len(batch_y)
        del optimizer
        return total_loss / (self.datavol * self.epochs)


    def unlearn(self, ):
        """Abstract hook; child classes must override it.

        NOTE(review): ``train`` and ``test`` read ``self.unlearn`` as a boolean
        flag. Unless an instance attribute of that name shadows this method,
        those checks see a bound method, which is always truthy.

        Raises:
            NotImplementedError: always. It subclasses RuntimeError, so
                handlers written for the previous RuntimeError still match.
        """
        raise NotImplementedError('This function must be rewritten in the child class. ')

    def post_train(self, ):
        """Abstract hook; child classes must override it.

        Raises:
            NotImplementedError: always. It subclasses RuntimeError, so
                handlers written for the previous RuntimeError still match.
        """
        raise NotImplementedError('This function must be rewritten in the child class. ')
    def test(self, model=None, dataflag='test'):
        # TODO: 这里有显存可以优化
        test_model = model if model is not None else self.model
        test_model.eval()

        if dataflag == 'train':
            dataset = self.train_data
            datavol = self.datavol
        else:
            dataset = self.test_data
            datavol = self.test_datavol
        total_loss = 0.0
        num_correct = 0
        local_metric = {}
        correct_by_domain = {}
        debug_domain = {}
        with torch.no_grad():
            # for batch_id, (batch_x, batch_y) in enumerate(dataset):
            for batch_id, batch_data in enumerate(dataset):
                batch_x, batch_y, batch_d = batch_data['image'], batch_data['label'], batch_data['domain']
                if self.unlearn and self.bd:
                    batch_x, batch_y = self.bd_maker.add_backdoor(batch_x, batch_y)
                batch_x = self.data_to_device(batch_x, device=self.device)
                batch_y = self.data_to_device(batch_y, device=self.device)

                outputs = test_model(batch_x)
                batch_mean_loss = self.criterion(outputs, batch_y).item()
                y_pred = outputs.data.max(1, keepdim=True)[1]
                for dn in torch.unique(batch_d):
                    dn = dn.item()
                    domain_indices = torch.where(batch_d == dn)
                    if dn in debug_domain:
                        debug_domain[dn] += len(domain_indices[0])
                    else:
                        debug_domain[dn] = len(domain_indices[0])

                    if dn in correct_by_domain:
                        correct_by_domain[dn] += y_pred[domain_indices].eq(batch_y[domain_indices].view_as(y_pred[domain_indices])).long().cpu().sum().item()
                    else:
                        correct_by_domain[dn] = y_pred[domain_indices].eq(batch_y[domain_indices].view_as(y_pred[domain_indices])).long().cpu().sum().item()
                correct = y_pred.eq(batch_y.data.view_as(y_pred)).long().cpu().sum()
                num_correct += correct.item()
                total_loss += batch_mean_loss * len(batch_y)
            if not self.unlearn:
                local_metric.update({'retain_accuracy': 100 * num_correct / datavol, 'retain_loss': total_loss / datavol,
                                     'domain_metric': correct_by_domain})
            else:
                local_metric.update({'Backdoor_accuracy': 100 * num_correct / datavol,
                                    'Backdoor_loss': total_loss / datavol, 'domain_metric': correct_by_domain})
        if self.unlearn:
            # 统计unlearn memory acc 用训练集做指标的
            BD_correct = 0
            BD_loss = 0.0
            correct_by_domain = {}
            with torch.no_grad():
                # for batch_id, (batch_x, batch_y) in enumerate(self.UM_test_data):
                for batch_id, batch_data in enumerate(self.UM_test_data):
                    batch_x, batch_y, batch_d = batch_data['image'], batch_data['label'], batch_data['domain']
                    if self.unlearn and self.bd:
                        batch_x, batch_y = self.bd_maker.add_backdoor(batch_x, batch_y)
                    batch_x = self.data_to_device(batch_x, device=self.device)
                    batch_y = self.data_to_device(batch_y, device=self.device)

                    outputs = test_model(batch_x)
                    batch_mean_loss = self.criterion(outputs, batch_y).item()
                    y_pred = outputs.data.max(1, keepdim=True)[1]
                    for dn in torch.unique(batch_d):
                        dn = dn.item()
                        domain_indices = torch.where(batch_d == dn)
                        if dn in correct_by_domain:
                            correct_by_domain[dn] += y_pred[domain_indices].eq(
                                batch_y[domain_indices].view_as(y_pred[domain_indices])).long().cpu().sum().item()
                        else:
                            correct_by_domain[dn] = y_pred[domain_indices].eq(
                                batch_y[domain_indices].view_as(y_pred[domain_indices])).long().cpu().sum().item()
                    correct = y_pred.eq(batch_y.data.view_as(y_pred)).long().cpu().sum()
                    BD_correct += correct.item()
                    BD_loss += batch_mean_loss * len(batch_y)
                local_metric.update({'Unlearn_Memory_accuracy': 100 * BD_correct / self.UM_test_datavol,
                                     'Unlearn_Memory_loss': BD_loss / self.UM_test_datavol, 'UM_domain_metric': correct_by_domain})
        return local_metric
    def unpack(self, received_pkg):
        """Copy the server package's fields onto this client.

        Args:
            received_pkg: Mapping with 'current_rounds', 'lr', 'momentum',
                'weight_decay', 'stage' and 'model'; each value is stored
                under the attribute of the same name.
        """
        package_fields = ('current_rounds', 'lr', 'momentum', 'weight_decay', 'stage', 'model')
        for field_name in package_fields:
            setattr(self, field_name, received_pkg[field_name])

    def reply(self, server_pack=None):
        """Handle one server request: unpack it, train locally, pack the result.

        Args:
            server_pack: Package from the server; must not be None.

        Returns:
            Whatever ``self.pack`` builds from the training loss.
        """
        assert server_pack is not None
        self.unpack(server_pack)
        return self.pack(self.train())

    def cal_grad_loss(self, svr_pkg):
        """Unpack a server package, then return the local gradient and loss.

        Args:
            svr_pkg: Package from the server; must not be None.

        Returns:
            The ``(gradient, loss)`` pair from ``self.cal_gradient_loss()``.
        """
        assert svr_pkg is not None
        # unpack() in this class stores the model itself and returns None;
        # assigning that result used to overwrite self.model with None. A
        # non-None return (e.g. from an overriding unpack) is still honoured.
        unpacked_model = self.unpack(svr_pkg)
        if unpacked_model is not None:
            self.model = unpacked_model
        g, l = self.cal_gradient_loss()
        return g, l

    def pack(self, loss, model=None):
        """Build the package sent back to the server.

        Args:
            loss: Value stored under "loss".
            model: Model to send; when None, ``self.model`` is sent.

        Returns:
            A dict with a deep copy of the chosen model under "model" and
            ``loss`` under "loss".
        """
        outgoing_model = self.model if model is None else model
        return {"model": copy.deepcopy(outgoing_model), "loss": loss}

    def set_model(self, model):
        """Replace the client's local model with ``model`` (stored by reference)."""
        self.model = model

    def set_learning_rate(self, lr=None):
        """Update the learning rate used by ``get_optimizer``.

        Args:
            lr: New learning rate. A falsy value (None or 0) keeps the current
                one.
        """
        # The rest of the class stores the rate as ``self.lr`` (set in
        # __init__ and unpack, read by get_optimizer); the previous body only
        # touched ``self.learning_rate``, which nothing else reads and which
        # raised AttributeError when lr was falsy.
        if lr:
            self.lr = lr
        # Keep the old attribute name in sync for any external reader.
        self.learning_rate = self.lr

    def update_device(self, dev):
        """Switch the client to ``dev`` and rebuild its task calculator.

        The new calculator is built by ``fmodule.TaskCalculator`` from ``dev``
        and the optimizer name of the current ``self.calculator``.

        NOTE(review): ``self.calculator`` is not assigned anywhere in the
        visible part of this class - confirm it is set before this is called.
        """
        self.device = dev
        calculator_cls = fmodule.TaskCalculator
        self.calculator = calculator_cls(dev, self.calculator.optimizer_name)

    def get_optimizer(self, optim_model):
        """Build the optimizer named by ``self.optimizer_name`` for ``optim_model``.

        Args:
            optim_model: Module whose ``parameters()`` are optimised.

        Returns:
            ``optim.SGD`` (using lr, weight_decay and momentum) for 'SGD', or
            ``optim.Adam`` (using lr and weight_decay) for 'Adam'.

        Raises:
            ValueError: for any other optimizer name.
        """
        name = self.optimizer_name
        if name == 'SGD':
            return optim.SGD(optim_model.parameters(), lr=self.lr, weight_decay=self.weight_decay,
                             momentum=self.momentum)
        if name == 'Adam':
            return optim.Adam(optim_model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        raise ValueError(f"Unsupported optimizer: {self.optimizer_name}")